using Random;
using CPUTime;

include("peps.jl");
include("boundeddists.jl");
include("samplingrules.jl");
include("reco_stop_pairs.jl");

"""
    runit(seed, is, pep, δs, Tau_max)

Run the learning algorithm, parameterised by a sampling rule. The stopping and
recommendation rules are shared by all sampling rules and are passed together
with it as `is = (sr, rsp)`.

The thresholds `βs` are built from `δs` with `get_threshold` and consumed front to
back with `popfirst!`: each time the stopping test passes for the current
threshold, a record is appended and the next threshold is tried against the same
statistics. The thresholds `βs` must therefore be in *increasing* order (how that
maps onto the order of `δs` depends on `get_threshold` — TODO confirm). For
non-anytime sampling rules, which use `βs[1]` inside the sampling rule itself,
`βs` should be a singleton.

# Arguments
- `seed`: seed of the `MersenneTwister` used for every sample.
- `is`: tuple `(sr, rsp)` of sampling rule and recommendation/stopping rule pair.
- `pep`: problem instance; arm `k` is sampled from `getdist(pep, k)`.
- `δs`: confidence levels handed to `get_threshold`.
- `Tau_max`: sample budget. Once the total number of samples exceeds it, a warning
  is logged and the run ends.

# Returns
A vector of `(answer, copy of per-arm counts N, CPUtime_us() elapsed since start)`
tuples: one per threshold that was crossed, plus one final entry if `Tau_max` was
exceeded before all thresholds were crossed.
"""
function runit(seed, is, pep, δs, Tau_max)
    sr, rsp = is
    # `isa` (rather than `typeof(x) == T`) also matches parametric/abstract types.
    gap_sr = sr isa LUCB

    K = nanswers(pep)

    # Get thresholds
    βs = get_threshold(rsp.threshold, δs, 2, K, 2)

    rng = MersenneTwister(seed)

    N = zeros(Int64, K)          # number of samples of each arm
    S = zeros(K)                 # sum of the samples of each arm
    Xs = [Any[] for _ in 1:K]    # full sample history of each arm

    baseline = CPUtime_us()

    # Initialisation: pull each arm n0 times.
    n0 = 1
    for k in 1:K, _ in 1:n0
        _pull_arm!(S, N, Xs, rng, pep, k)
    end

    state = start(sr, N)
    R = Tuple{Int64, Array{Int64,1}, UInt64}[]  # collected return values

    # Cached answers, reused in rounds where the GLR statistic is not recomputed.
    lazy_astar = argmax(S ./ N)
    lazy_aalt = lazy_astar

    while true
        t = sum(N)

        # Empirical means.
        hμ = S ./ N

        if gap_sr
            # The threshold enters the sampling rule itself, and the returned
            # `ucb` doubles as the stopping statistic.
            astar, k, ucb = nextsample(state, pep, hμ, N, Xs, βs[1](t))
            ks = [astar, k]

            # `ucb` is not recomputed inside this loop, so once it is <= 0 every
            # remaining threshold is recorded at the current time.
            while ucb <= 0
                popfirst!(βs)
                push!(R, (astar, copy(N), CPUtime_us() - baseline))
                if isempty(βs)
                    return R
                end
            end
        else
            # Recompute the GLR statistics if the rule always requires it, at the
            # first round, or when `compute_GLR` asks for it (it is told whether
            # the empirical best arm differs from the cached one).
            if rsp isa GLRT || (t == n0 * K) ||
               compute_GLR(rsp, t, lazy_astar != argmax(hμ))
                Zs, (aalt, _), (astar, ξ) = glrt(pep, N, hμ, Xs)

                # Test the stopping criterion, moving on to the next (larger)
                # threshold each time the current one is crossed.
                while stopping_criterion(Zs, βs[1], N, astar, aalt)
                    popfirst!(βs)
                    push!(R, (astar, copy(N), CPUtime_us() - baseline))
                    if isempty(βs)
                        return R
                    end
                end
                lazy_astar = astar
                lazy_aalt = aalt
            else
                astar, ξ, aalt = lazy_astar, hμ, lazy_aalt
            end
            # Invoke the sampling rule.
            k = nextsample(state, pep, astar, aalt, ξ, N, Xs, rng)
            ks = [k]
        end

        for k in ks
            _pull_arm!(S, N, Xs, rng, pep, k)
        end
        t += length(ks)

        if t > Tau_max
            @warn "Finite-time horizon Tau_max = $(Tau_max) met without stopping for " * abbrev(sr) * ". Increase this if hard problem."
            push!(R, (astar, copy(N), CPUtime_us() - baseline))
            return R
        end
    end
end

# Draw one sample from arm `k` of `pep` using `rng`, and record it in the running
# sums `S`, the counts `N` and the sample history `Xs`. Returns the drawn sample.
function _pull_arm!(S, N, Xs, rng, pep, k)
    X = sample(rng, getdist(pep, k))
    S[k] += X
    N[k] += 1
    push!(Xs[k], X)
    return X
end

"""
    stopping_criterion(Zs, β, N, astar, aalt)

Return `true` when the statistics `Zs` exceed the threshold `β`.

- If `abbrev(β) == "FTT"`, `β` is a function of a pair of counts, and every
  challenger `a ≠ astar` must satisfy `Zs[a] > β(N[astar], N[a])`. This is
  vacuously `true` when there is a single arm.
- Otherwise `β` is a function of the total sample count, and only the answer
  `aalt` is tested: `Zs[aalt] > β(sum(N))`.
"""
function stopping_criterion(Zs, β, N, astar, aalt)
    if abbrev(β) == "FTT"
        # Stop only if every challenger is rejected; `all` short-circuits on the
        # first challenger that is not.
        return all(Zs[a] > β(N[astar], N[a]) for a in eachindex(N) if a != astar)
    end
    return Zs[aalt] > β(sum(N))
end
